{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OXYgXFeMgRep"
      },
      "source": [
        "Copyright 2019 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NcIzzCADklYm"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/google-research/google-research.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ngihcW7ckrDI"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "import tarfile\n",
        "import urllib\n",
        "import zipfile\n",
        "sys.path.append('./google-research')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y55h79H3XKSt"
      },
      "source": [
        "# Examples of streaming and non-streaming inference with TF/TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fathHzuEgx8_"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yP5WBy5O8Za8"
      },
      "outputs": [],
      "source": [
        "# TF streaming\n",
        "from kws_streaming.models import models\n",
        "from kws_streaming.models import utils\n",
        "from kws_streaming.models import model_utils\n",
        "from kws_streaming.layers.modes import Modes\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wsUCmBzpk1jC"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import tensorflow.compat.v1 as tf1\n",
        "import logging\n",
        "from kws_streaming.models import model_flags\n",
        "from kws_streaming.models import model_params\n",
        "from kws_streaming.train import inference\n",
        "from kws_streaming.train import test\n",
        "from kws_streaming.data import input_data\n",
        "from kws_streaming.data import input_data_utils as du\n",
        "tf1.disable_eager_execution()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jow_HMLAU7LR"
      },
      "outputs": [],
      "source": [
        "config = tf1.ConfigProto()\n",
        "config.gpu_options.allow_growth = True\n",
        "sess = tf1.Session(config=config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zMdTK10tL2Dz"
      },
      "outputs": [],
      "source": [
        "# general imports\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import json\n",
        "import numpy as np\n",
        "import scipy as scipy\n",
        "import scipy.io.wavfile as wav\n",
        "import scipy.signal"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L_F-8OFCU7La"
      },
      "outputs": [],
      "source": [
        "tf.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xHTcbg_ao586"
      },
      "outputs": [],
      "source": [
        "# Clear the default graph and bind Keras to a fresh session.\n",
        "tf1.reset_default_graph()\n",
        "# Reuse `config` from the cell above: a bare tf1.Session() would silently\n",
        "# drop gpu_options.allow_growth for the session Keras actually uses.\n",
        "sess = tf1.Session(config=config)\n",
        "tf1.keras.backend.set_session(sess)\n",
        "# Learning phase 0 selects test/inference behaviour.\n",
        "tf1.keras.backend.set_learning_phase(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylPGCTPLh41F"
      },
      "source": [
        "## Load wav file"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b8Bvq7XacsOu"
      },
      "outputs": [],
      "source": [
        "def waveread_as_pcm16(filename):\n",
        "  \"\"\"Reads a wav file and returns (samples, sample_rate).\n",
        "\n",
        "  The samples are returned exactly as scipy.io.wavfile.read produced them;\n",
        "  no dtype conversion is done here.\n",
        "  \"\"\"\n",
        "  # scipy returns (rate, data); callers in this notebook expect (data, rate).\n",
        "  samplerate, wave_data = wav.read(filename)\n",
        "  return wave_data, samplerate\n",
        "\n",
        "def wavread_as_float(filename, target_sample_rate=16000):\n",
        "  \"\"\"Reads a wav file as float32 samples at target_sample_rate.\n",
        "\n",
        "  Args:\n",
        "    filename: path of the wav file to read.\n",
        "    target_sample_rate: sample rate, in Hz, of the returned audio.\n",
        "\n",
        "  Returns:\n",
        "    (data, target_sample_rate), where data is a float32 array holding the\n",
        "    samples divided by 32768.0.\n",
        "  \"\"\"\n",
        "  wave_data, samplerate = waveread_as_pcm16(filename)\n",
        "  # Resample only when the rates differ: FFT-resampling to the same length\n",
        "  # is wasted work and perturbs the samples with round-off noise.\n",
        "  if samplerate != target_sample_rate:\n",
        "    desired_length = int(\n",
        "        round(float(len(wave_data)) / samplerate * target_sample_rate))\n",
        "    wave_data = scipy.signal.resample(wave_data, desired_length)\n",
        "\n",
        "  # Normalize short ints to floats in range [-1..1).\n",
        "  # The same constant is applied whatever the file's sample format is, so\n",
        "  # this scaling is only meaningful for 16-bit PCM input.\n",
        "  data = np.array(wave_data, np.float32) / 32768.0\n",
        "  return data, target_sample_rate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e6MDIFztU7Lp"
      },
      "outputs": [],
      "source": [
        "# set PATH to data sets (for example to speech commands V2):\n",
        "# it can be downloaded from\n",
        "# https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz\n",
        "# if you run 00_check-data.ipynb then data2 should be located in the current folder\n",
        "current_dir = os.getcwd()\n",
        "DATA_PATH = os.path.join(current_dir, \"data2/\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TYj0JGeHhtqc"
      },
      "outputs": [],
      "source": [
        "# Set path to wav file for testing.\n",
        "wav_file = os.path.join(DATA_PATH, \"left/012187a4_nohash_0.wav\")\n",
        "\n",
        "# read audio file\n",
        "wav_data, samplerate = wavread_as_float(wav_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "00CBA81RU7Lz"
      },
      "outputs": [],
      "source": [
        "assert samplerate == 16000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r2yeKkLsiRWJ"
      },
      "outputs": [],
      "source": [
        "plt.plot(wav_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5_wbAZ3vhQh1"
      },
      "source": [
        "## Prepare batched model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SYl2VSAhU7L_"
      },
      "outputs": [],
      "source": [
        "# This notebook is configured to work with 'ds_tc_resnet' and 'svdf'.\n",
        "MODEL_NAME = 'ds_tc_resnet'\n",
        "# MODEL_NAME = 'svdf'\n",
        "MODELS_PATH = os.path.join(current_dir, \"models\")\n",
        "MODEL_PATH = os.path.join(MODELS_PATH, MODEL_NAME + \"/\")\n",
        "MODEL_PATH"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPpyvR66URVe"
      },
      "outputs": [],
      "source": [
        "train_dir = os.path.join(MODELS_PATH, MODEL_NAME)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mH-fyuESU7MD"
      },
      "outputs": [],
      "source": [
        "# Another way of reading flags: load the flags.json stored next to the model.\n",
        "flags_path = os.path.join(train_dir, 'flags.json')\n",
        "with tf.compat.v1.gfile.Open(flags_path, 'r') as flags_file:\n",
        "  flags_json = json.load(flags_file)\n",
        "\n",
        "\n",
        "class DictStruct(object):\n",
        "  \"\"\"Plain namespace whose attributes are the keyword arguments it was built with.\"\"\"\n",
        "\n",
        "  def __init__(self, **entries):\n",
        "    vars(self).update(entries)\n",
        "\n",
        "\n",
        "# Expose every JSON key as an attribute, e.g. flags.model_name.\n",
        "flags = DictStruct(**flags_json)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T-lmp__rUVFz"
      },
      "outputs": [],
      "source": [
        "flags.data_dir = DATA_PATH"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PSfkGIgAU7MO"
      },
      "outputs": [],
      "source": [
        "# Total stride of the model (1 unless the model pools/strides).\n",
        "total_stride = 1\n",
        "if MODEL_NAME == 'ds_tc_resnet':\n",
        "  # This could be automated by scanning the layers of the model, but for now\n",
        "  # it is derived from the pooling/stride flags of this specific model.\n",
        "  pools = model_utils.parse(flags.ds_pool)\n",
        "  strides = model_utils.parse(flags.ds_stride)\n",
        "  # Collect every pool and stride factor greater than 1 (pools first, then\n",
        "  # strides); the leading 1 keeps the product defined when there are none.\n",
        "  time_stride = (\n",
        "      [1]\n",
        "      + [pool for pool in pools if pool \u003e 1]\n",
        "      + [stride for stride in strides if stride \u003e 1])\n",
        "  total_stride = np.prod(time_stride)\n",
        "\n",
        "# override input data shape for streaming model with stride/pool\n",
        "flags.data_stride = total_stride\n",
        "flags.data_shape = (total_stride * flags.window_stride_samples,)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gkWTgtOWU7MT"
      },
      "outputs": [],
      "source": [
        "# prepare mapping of index to word\n",
        "# Import the module under a private alias: a later cell rebinds the name\n",
        "# `input_data` to the padded audio array, and going through the alias keeps\n",
        "# this cell re-runnable after that has happened.\n",
        "from kws_streaming.data import input_data as input_data_lib\n",
        "\n",
        "audio_processor = input_data_lib.AudioProcessor(flags)\n",
        "\n",
        "# labels used for training\n",
        "index_to_label = {}\n",
        "for word, index in audio_processor.word_to_index.items():\n",
        "  if index == du.SILENCE_INDEX:\n",
        "    index_to_label[index] = du.SILENCE_LABEL\n",
        "  elif index == du.UNKNOWN_WORD_INDEX:\n",
        "    index_to_label[index] = du.UNKNOWN_WORD_LABEL\n",
        "  else:\n",
        "    index_to_label[index] = word\n",
        "\n",
        "# training labels\n",
        "index_to_label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U1al1r1PU7MR"
      },
      "outputs": [],
      "source": [
        "# Fit the audio to exactly flags.desired_samples: clips that are too long are\n",
        "# truncated first (np.pad rejects a negative pad width), shorter ones are\n",
        "# padded with zeros at the end.\n",
        "clipped_wav = wav_data[:flags.desired_samples]\n",
        "padded_wav = np.pad(\n",
        "    clipped_wav, (0, flags.desired_samples - len(clipped_wav)), 'constant')\n",
        "\n",
        "# Add the batch dimension. NOTE: this rebinds `input_data`, which until here\n",
        "# referred to the kws_streaming.data.input_data module; the cells below expect\n",
        "# the array.\n",
        "input_data = np.expand_dims(padded_wav, 0)\n",
        "input_data.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wsGDG4A0cIMO"
      },
      "outputs": [],
      "source": [
        "# create model with flag's parameters\n",
        "model_non_stream_batch = models.MODELS[flags.model_name](flags)\n",
        "\n",
        "# load model's weights\n",
        "weights_name = 'best_weights'\n",
        "model_non_stream_batch.load_weights(os.path.join(train_dir, weights_name))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QVhESthmMl0X"
      },
      "outputs": [],
      "source": [
        "tf.keras.utils.plot_model(\n",
        "    model_non_stream_batch,\n",
        "    show_shapes=True,\n",
        "    show_layer_names=True,\n",
        "    expand_nested=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RIr1DWLisMu9"
      },
      "source": [
        "## Run inference with TF"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "456ynjRxmdVc"
      },
      "source": [
        "### TF Run non-streaming inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-vJpOCJClDK5"
      },
      "outputs": [],
      "source": [
        "# convert model to inference mode with batch one\n",
        "inference_batch_size = 1\n",
        "tf.keras.backend.set_learning_phase(0)\n",
        "flags.batch_size = inference_batch_size  # set batch size\n",
        "\n",
        "model_non_stream = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.NON_STREAM_INFERENCE)\n",
        "#model_non_stream.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O1gOGQjWMufh"
      },
      "outputs": [],
      "source": [
        "tf.keras.utils.plot_model(\n",
        "    model_non_stream,\n",
        "    show_shapes=True,\n",
        "    show_layer_names=True,\n",
        "    expand_nested=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nPUfT4a4lxIj"
      },
      "outputs": [],
      "source": [
        "predictions = model_non_stream.predict(input_data)\n",
        "predicted_labels = np.argmax(predictions, axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "63sisD1hl7jz"
      },
      "outputs": [],
      "source": [
        "predicted_labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rBhLA1OZmQxj"
      },
      "outputs": [],
      "source": [
        "index_to_label[predicted_labels[0]]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZVFoVdYSpnL_"
      },
      "source": [
        "### TF Run streaming inference with internal state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cgcpcrASquAY"
      },
      "outputs": [],
      "source": [
        "# convert model to streaming mode\n",
        "flags.batch_size = inference_batch_size  # set batch size\n",
        "\n",
        "model_stream = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.STREAM_INTERNAL_STATE_INFERENCE)\n",
        "#model_stream.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BNtgTOBCM06v"
      },
      "outputs": [],
      "source": [
        "tf.keras.utils.plot_model(\n",
        "    model_stream,\n",
        "    show_shapes=True,\n",
        "    show_layer_names=True,\n",
        "    expand_nested=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7NOG8wrYpnnq"
      },
      "outputs": [],
      "source": [
        "stream_output_prediction = inference.run_stream_inference_classification(flags, model_stream, input_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S-xeXPhAqC20"
      },
      "outputs": [],
      "source": [
        "stream_output_arg = np.argmax(stream_output_prediction)\r\n",
        "stream_output_arg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FlM0ppVIU0rW"
      },
      "outputs": [],
      "source": [
        "index_to_label[stream_output_arg]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F5WYgOtSqrQb"
      },
      "source": [
        "### TF Run streaming inference with external state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2hTLEY1qq_ig"
      },
      "outputs": [],
      "source": [
        "# convert model to streaming mode with external state\n",
        "flags.batch_size = inference_batch_size  # set batch size\n",
        "\n",
        "model_stream_external = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.STREAM_EXTERNAL_STATE_INFERENCE)\n",
        "#model_stream_external.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AyeABeg9Mbf6"
      },
      "outputs": [],
      "source": [
        "tf.keras.utils.plot_model(\n",
        "    model_stream_external,\n",
        "    show_shapes=True,\n",
        "    show_layer_names=True,\n",
        "    expand_nested=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RISdLTnmqrcA"
      },
      "outputs": [],
      "source": [
        "# One zero-filled array per model input: index 0 carries the audio, the\n",
        "# remaining entries are the streaming states fed back on every step.\n",
        "inputs = [\n",
        "    np.zeros(tensor.shape, dtype=np.float32)\n",
        "    for tensor in model_stream_external.inputs\n",
        "]\n",
        "\n",
        "window_stride = flags.data_shape[0]\n",
        "num_samples = input_data.shape[1]\n",
        "\n",
        "# Walk the audio in consecutive, non-overlapping windows of window_stride\n",
        "# samples; a trailing partial window is not processed.\n",
        "for start in range(0, num_samples - window_stride + 1, window_stride):\n",
        "  # set input audio data (by default input data at index 0)\n",
        "  inputs[0] = input_data[:, start:start + window_stride]\n",
        "\n",
        "  # run inference\n",
        "  outputs = model_stream_external.predict(inputs)\n",
        "\n",
        "  # copy output states back into the input states,\n",
        "  # so they are fed in the next inference cycle\n",
        "  for state_index in range(1, len(inputs)):\n",
        "    inputs[state_index] = outputs[state_index]\n",
        "\n",
        "  # classification of the most recent window (output 0)\n",
        "  stream_output_arg = np.argmax(outputs[0])\n",
        "stream_output_arg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u6p1xubwrYyo"
      },
      "outputs": [],
      "source": [
        "index_to_label[stream_output_arg]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KAJs5dBXsYCa"
      },
      "source": [
        "## Run inference with TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z5qmO5KrU7NP"
      },
      "source": [
        "### Run non-streaming inference with TFLite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "88bclN4rtu-5"
      },
      "outputs": [],
      "source": [
        "tflite_non_streaming_model = utils.model_to_tflite(sess, model_non_stream_batch, flags, Modes.NON_STREAM_INFERENCE)\n",
        "tflite_non_stream_fname = 'tflite_non_stream.tflite'\n",
        "with open(os.path.join(MODEL_PATH, tflite_non_stream_fname), 'wb') as fd:\n",
        "  fd.write(tflite_non_streaming_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VZgH11_0u2ZN"
      },
      "outputs": [],
      "source": [
        "interpreter = tf.lite.Interpreter(model_content=tflite_non_streaming_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3J2n7VB5JxV6"
      },
      "outputs": [],
      "source": [
        "# set input audio data (by default input data at index 0)\n",
        "interpreter.set_tensor(input_details[0]['index'], input_data.astype(np.float32))\n",
        "\n",
        "# run inference\n",
        "interpreter.invoke()\n",
        "\n",
        "# get output: classification\n",
        "out_tflite = interpreter.get_tensor(output_details[0]['index'])\n",
        "\n",
        "out_tflite_argmax = np.argmax(out_tflite)\n",
        "\n",
        "out_tflite_argmax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TXqHxLcVregL"
      },
      "outputs": [],
      "source": [
        "index_to_label[out_tflite_argmax]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xNaUWgivuatL"
      },
      "source": [
        "### Run streaming inference with TFLite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "csQWZo4BuqEB"
      },
      "outputs": [],
      "source": [
        "tflite_streaming_model = utils.model_to_tflite(sess, model_non_stream_batch, flags, Modes.STREAM_EXTERNAL_STATE_INFERENCE)\n",
        "tflite_stream_fname = 'tflite_stream.tflite'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4wAZqYouyob"
      },
      "outputs": [],
      "source": [
        "with open(os.path.join(MODEL_PATH, tflite_stream_fname), 'wb') as fd:\n",
        "  fd.write(tflite_streaming_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "03QCq1nfVUWW"
      },
      "outputs": [],
      "source": [
        "# Build a TFLite interpreter directly from the in-memory streaming model.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_streaming_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "\n",
        "# One zero-filled float32 array per interpreter input, shaped like that input.\n",
        "input_states = [\n",
        "    np.zeros(detail['shape'], dtype=np.float32) for detail in input_details\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WKudF1Zyud2-"
      },
      "outputs": [],
      "source": [
        "out_tflite = inference.run_stream_inference_classification_tflite(flags, interpreter, input_data, input_states)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yWy_BiepFFSX"
      },
      "outputs": [],
      "source": [
        "out_tflite_argmax = np.argmax(out_tflite[0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QSa7AX1GvReF"
      },
      "outputs": [],
      "source": [
        "index_to_label[out_tflite_argmax]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NMlGC6anVX_s"
      },
      "source": [
        "## Run evaluation on all testing data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WVQZXLyDU7N3"
      },
      "outputs": [],
      "source": [
        "test.tflite_non_stream_model_accuracy(\r\n",
        "    flags,\r\n",
        "    MODEL_PATH,\r\n",
        "    tflite_model_name=tflite_non_stream_fname,\r\n",
        "    accuracy_name='tflite_non_stream_model_accuracy.txt')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uoh8wIATVO40"
      },
      "outputs": [],
      "source": [
        "# Accuracy of the streaming TFLite model with reset_state=True.\n",
        "# Use a reset-specific accuracy_name: the following cell evaluates the same\n",
        "# model with reset_state=False, and sharing one name would let the two\n",
        "# results collide.\n",
        "test.tflite_stream_state_external_model_accuracy(\n",
        "    flags,\n",
        "    MODEL_PATH,\n",
        "    tflite_model_name=tflite_stream_fname,\n",
        "    accuracy_name='tflite_stream_state_external_model_accuracy_reset1.txt',\n",
        "    reset_state=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "crtqbyWaVQEX"
      },
      "outputs": [],
      "source": [
        "# Accuracy of the streaming TFLite model with reset_state=False.\n",
        "# Use a reset-specific accuracy_name: the previous cell evaluates the same\n",
        "# model with reset_state=True, and sharing one name would let the two\n",
        "# results collide.\n",
        "test.tflite_stream_state_external_model_accuracy(\n",
        "    flags,\n",
        "    MODEL_PATH,\n",
        "    tflite_model_name=tflite_stream_fname,\n",
        "    accuracy_name='tflite_stream_state_external_model_accuracy_reset0.txt',\n",
        "    reset_state=False)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "02_inference.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
